import time
import pandas as pd
from experiments.utils import load_dataset_safely, seed_everything, safe_import


def _label_strings(labels):
    """Return *labels* as a 1-D array of strings for equality comparison.

    Integral floats are rendered like integers (``1.0`` -> ``"1"``) so that a
    class label compares equal whether it arrives as a float or as an int.
    Every other value is rendered with ``str``.
    """
    def _render(value):
        if isinstance(value, float) and value.is_integer():
            return str(int(value))
        return str(value)

    return pd.Series(labels).map(_render).to_numpy()


def h2o_automl_baseline(dataset_name: str, max_runtime_secs: int = 300, random_state: int = 42) -> dict:
    """Run H2O AutoML on a dataset and return validation accuracy & runtime.

    The dataset is fetched with ``load_dataset_safely`` and must provide the
    keys ``X_train``, ``y_train``, ``X_val`` and ``y_val``. The target is
    always treated as categorical (classification). A feature column that is
    already named ``'target'`` is overwritten by the labels.

    Args:
        dataset_name: Name passed to ``load_dataset_safely``.
        max_runtime_secs: Time budget handed to ``H2OAutoML``.
        random_state: Seed passed to ``seed_everything`` and ``H2OAutoML``.

    Returns:
        A dict with ``"val_score"`` (fraction of validation rows whose
        predicted label equals the true label) and ``"time_sec"`` (seconds
        from just before cluster start until training/scoring finished or
        failed; cleanup is not included). If anything fails after the dataset
        is loaded, the error is printed and ``"val_score"`` is ``0.0``.

    Raises:
        RuntimeError: If ``h2o`` or ``H2OAutoML`` cannot be imported, or if
            the dataset cannot be loaded.
    """
    seed_everything(random_state)

    h2o_mod, err = safe_import("h2o")
    if h2o_mod is None:
        raise RuntimeError(f"H2O not installed: {err}")

    try:
        from h2o.automl import H2OAutoML  # type: ignore
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to import H2OAutoML: {e}") from e

    data, msg = load_dataset_safely(dataset_name)
    if data is None:
        raise RuntimeError(msg)

    X_train, y_train = data["X_train"], data["y_train"]
    X_val, y_val = data["X_val"], data["y_val"]

    # Ensure pandas DataFrame/Series
    if not hasattr(X_train, 'columns'):
        X_train = pd.DataFrame(X_train)
    if not hasattr(X_val, 'columns'):
        X_val = pd.DataFrame(X_val)

    start = time.time()
    score = 0.0
    cluster_started = False

    try:
        # Initialize H2O cluster
        h2o_mod.init(max_mem_size="4G", nthreads=-1)
        cluster_started = True

        # Prepare train & validation frames.
        # Labels are attached positionally: assigning a pandas Series directly
        # would align on the index and yield NaN/misplaced labels whenever the
        # features and labels do not share the same index (e.g. features that
        # were rebuilt from an array above get a fresh RangeIndex).
        train_df = X_train.copy()
        train_df['target'] = pd.Series(y_train).to_numpy()
        val_df = X_val.copy()
        val_df['target'] = pd.Series(y_val).to_numpy()

        train_h2o = h2o_mod.H2OFrame(train_df)
        val_h2o = h2o_mod.H2OFrame(val_df)

        # Force target as categorical for classification
        train_h2o['target'] = train_h2o['target'].asfactor()
        val_h2o['target'] = val_h2o['target'].asfactor()

        features = [c for c in train_h2o.columns if c != 'target']

        automl = H2OAutoML(
            max_runtime_secs=max_runtime_secs,
            seed=random_state,
            sort_metric="AUTO",
            verbosity=None,
            max_models=10,          # limit for small datasets
            exclude_algos=["DeepLearning"],  # avoid overfitting tiny datasets
        )
        automl.train(x=features, y='target', training_frame=train_h2o)

        # Predictions
        preds = automl.leader.predict(val_h2o)
        try:
            pred_df = preds.as_data_frame(use_multi_thread=True)
        except TypeError:
            # h2o releases without the `use_multi_thread` keyword.
            pred_df = preds.as_data_frame()

        # Compare as strings for accuracy. Both sides go through the same
        # normalisation so that integral-float labels (1.0) match labels that
        # come back integer-looking (1) after the round trip through H2O.
        predicted = _label_strings(pred_df['predict'])
        expected = _label_strings(y_val)
        score = float((predicted == expected).mean())

    except Exception as e:
        # Best-effort benchmark: report the failure and fall back to score 0.
        print("H2O AutoML failed:", e)
        score = 0.0
    finally:
        duration = time.time() - start
        # Cleanup. Each step runs on its own so that a failing remove_all()
        # cannot prevent the cluster from being shut down.
        if cluster_started:
            try:
                h2o_mod.remove_all()
            except Exception as cleanup_err:
                print("H2O cleanup (remove_all) failed:", cleanup_err)
            try:
                h2o_mod.shutdown(prompt=False)
            except Exception as cleanup_err:
                print("H2O cleanup (shutdown) failed:", cleanup_err)

    return {
        "val_score": score,
        "time_sec": duration,
    }


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point: run the baseline once and print its result dict.
    cli = ArgumentParser()
    cli.add_argument("--dataset", default="iris")
    cli.add_argument(
        "--time",
        type=int,
        default=300,
        help="Max runtime seconds for H2O AutoML",
    )
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    print(
        h2o_automl_baseline(
            opts.dataset,
            max_runtime_secs=opts.time,
            random_state=opts.seed,
        )
    )
